{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "<h1> Create TensorFlow wide-and-deep model </h1>\n",
    "\n",
    "This notebook illustrates:\n",
    "<ol>\n",
    "<li> Creating a model using the high-level Estimator API \n",
    "</ol>"
   ]
  },

  {
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
 "!sudo chown -R jupyter:jupyter /home/jupyter/training-data-analyst"
   ]
  },
  {

"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
 "# Ensure the right version of Tensorflow is installed.\n",
 "!pip freeze | grep tensorflow==2.1"
 ]
},
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# change these to try this notebook out\n",
    "BUCKET = 'cloud-training-demos-ml'\n",
    "PROJECT = 'cloud-training-demos'\n",
    "REGION = 'us-central1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['BUCKET'] = BUCKET\n",
    "os.environ['PROJECT'] = PROJECT\n",
    "os.environ['REGION'] = REGION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "if ! gcloud storage ls | grep -q gs://${BUCKET}/; then\n",
    "  gcloud storage buckets create --location=${REGION} gs://${BUCKET}\n",
    "fi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "ls *.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "<h2> Create TensorFlow model using TensorFlow's Estimator API </h2>\n",
    "<p>\n",
    "First, write an input_fn to read the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Determine CSV, label, and key columns\n",
    "CSV_COLUMNS = 'weight_pounds,is_male,mother_age,plurality,gestation_weeks,key'.split(',')\n",
    "LABEL_COLUMN = 'weight_pounds'\n",
    "KEY_COLUMN = 'key'\n",
    "\n",
    "# Set default values for each CSV column\n",
    "DEFAULTS = [[0.0], ['null'], [0.0], ['null'], [0.0], ['nokey']]\n",
    "TRAIN_STEPS = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Create an input function reading a file using the Dataset API\n",
    "# Then provide the results to the Estimator API\n",
    "def read_dataset(filename, mode, batch_size = 512):\n",
    "  def _input_fn():\n",
    "    def decode_csv(value_column):\n",
    "      columns = tf.compat.v1.decode_csv(value_column, record_defaults=DEFAULTS)\n",
    "      features = dict(zip(CSV_COLUMNS, columns))\n",
    "      label = features.pop(LABEL_COLUMN)\n",
    "      return features, label\n",
    "    \n",
    "    # Create list of files that match pattern\n",
    "    file_list = tf.compat.v1.gfile.Glob(filename)\n",
    "\n",
    "    # Create dataset from file list\n",
    "    dataset = (tf.compat.v1.data.TextLineDataset(file_list)  # Read text file\n",
    "                 .map(decode_csv))  # Transform each elem by applying decode_csv fn\n",
    "      \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        num_epochs = None # indefinitely\n",
    "        dataset = dataset.shuffle(buffer_size=10*batch_size)\n",
    "    else:\n",
    "        num_epochs = 1 # end-of-input after this\n",
    " \n",
    "    dataset = dataset.repeat(num_epochs).batch(batch_size)\n",
    "    return dataset\n",
    "  return _input_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Next, define the feature columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Define feature columns\n",
    "def get_wide_deep():\n",
    "  # Define column types\n",
    "  is_male,mother_age,plurality,gestation_weeks = \\\n",
    "      [\\\n",
    "          tf.feature_column.categorical_column_with_vocabulary_list('is_male', \n",
    "                      ['True', 'False', 'Unknown']),\n",
    "          tf.feature_column.numeric_column('mother_age'),\n",
    "          tf.feature_column.categorical_column_with_vocabulary_list('plurality',\n",
    "                      ['Single(1)', 'Twins(2)', 'Triplets(3)',\n",
    "                       'Quadruplets(4)', 'Quintuplets(5)','Multiple(2+)']),\n",
    "          tf.feature_column.numeric_column('gestation_weeks')\n",
    "      ]\n",
    "\n",
    "  # Discretize\n",
    "  age_buckets = tf.feature_column.bucketized_column(mother_age, \n",
    "                      boundaries=np.arange(15,45,1).tolist())\n",
    "  gestation_buckets = tf.feature_column.bucketized_column(gestation_weeks, \n",
    "                      boundaries=np.arange(17,47,1).tolist())\n",
    "\n",
    "  # Sparse columns are wide, have a linear relationship with the output\n",
    "  wide = [is_male,\n",
    "          plurality,\n",
    "          age_buckets,\n",
    "          gestation_buckets]\n",
    "\n",
    "  # Feature cross all the wide columns and embed into a lower dimension\n",
    "  crossed = tf.feature_column.crossed_column(wide, hash_bucket_size=20000)\n",
    "  embed = tf.feature_column.embedding_column(crossed, 3)\n",
    "\n",
    "  # Continuous columns are deep, have a complex relationship with the output\n",
    "  deep = [mother_age,\n",
    "          gestation_weeks,\n",
    "          embed]\n",
    "  return wide, deep"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "To predict with the TensorFlow model, we also need a serving input function. We will want all the inputs from our user."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Create serving input function to be able to serve predictions later using provided inputs\n",
    "def serving_input_fn():\n",
    "    feature_placeholders = {\n",
    "        'is_male': tf.compat.v1.placeholder(tf.string, [None]),\n",
    "        'mother_age': tf.compat.v1.placeholder(tf.float32, [None]),\n",
    "        'plurality': tf.compat.v1.placeholder(tf.string, [None]),\n",
    "        'gestation_weeks': tf.compat.v1.placeholder(tf.float32, [None])\n",
    "    }\n",
    "    features = {\n",
    "        key: tf.expand_dims(tensor, -1)\n",
    "        for key, tensor in feature_placeholders.items()\n",
    "    }\n",
    "    return tf.estimator.export.ServingInputReceiver(features, feature_placeholders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Create estimator to train and evaluate\n",
    "def train_and_evaluate(output_dir):\n",
    "  wide, deep = get_wide_deep()\n",
    "  EVAL_INTERVAL = 300\n",
    "  run_config = tf.estimator.RunConfig(save_checkpoints_secs = EVAL_INTERVAL,\n",
    "                                      keep_checkpoint_max = 3)\n",
    "  estimator = tf.estimator.DNNLinearCombinedRegressor(\n",
    "                       model_dir = output_dir,\n",
    "                       linear_feature_columns = wide,\n",
    "                       dnn_feature_columns = deep,\n",
    "                       dnn_hidden_units = [64, 32],\n",
    "                       config = run_config)\n",
    "  train_spec = tf.estimator.TrainSpec(\n",
    "                       input_fn = read_dataset('train.csv', mode = tf.estimator.ModeKeys.TRAIN),\n",
    "                       max_steps = TRAIN_STEPS)\n",
    "  exporter = tf.estimator.LatestExporter('exporter', serving_input_fn)\n",
    "  eval_spec = tf.estimator.EvalSpec(\n",
    "                       input_fn = read_dataset('eval.csv', mode = tf.estimator.ModeKeys.EVAL),\n",
    "                       steps = None,\n",
    "                       start_delay_secs = 60, # start evaluating after N seconds\n",
    "                       throttle_secs = EVAL_INTERVAL,  # evaluate every N seconds\n",
    "                       exporters = exporter)\n",
    "  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Finally, train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Run the model\n",
    "shutil.rmtree('babyweight_trained', ignore_errors = True) # start fresh each time\n",
    "tf.compat.v1.summary.FileWriterCache.clear()\n",
    "train_and_evaluate('babyweight_trained')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "The exporter directory contains the final model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
